Building neural network & competitor

Published

February 9, 2025

Packages

library(discrim)
library(keras3)
library(tensorflow)
library(tidymodels)

Data

General prep

set.seed(42)

spam <- readr::read_csv(here::here("data/spam.csv"))

spam <- 
  spam |> 
  mutate(
    # outcome has to be ordered factor for tidymodels:
    spam = factor(
      if_else(spam == 0, "no spam", "spam"),
      ordered = TRUE,
      levels = c("spam", "no spam")
    )
  )

# Data split (60/20/20):
spam_split <- initial_validation_split(spam, prop = c(0.6, 0.2), strata = "spam")

train <- training(spam_split)
val <- validation(spam_split)
test <- testing(spam_split)

Getting an overview:

glimpse(train)
Rows: 2,759
Columns: 58
$ word_freq_make             <dbl> 0.00, 0.21, 0.06, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_address          <dbl> 0.64, 0.28, 0.00, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_all              <dbl> 0.64, 0.50, 0.71, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_3d               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_our              <dbl> 0.32, 0.14, 1.23, 0.63, 0.63, 1.92, 1.88, 0…
$ word_freq_over             <dbl> 0.00, 0.28, 0.19, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_remove           <dbl> 0.00, 0.21, 0.19, 0.31, 0.31, 0.00, 0.00, 0…
$ word_freq_internet         <dbl> 0.00, 0.07, 0.12, 0.63, 0.63, 0.00, 1.88, 0…
$ word_freq_order            <dbl> 0.00, 0.00, 0.64, 0.31, 0.31, 0.00, 0.00, 0…
$ word_freq_mail             <dbl> 0.00, 0.94, 0.25, 0.63, 0.63, 0.64, 0.00, 0…
$ word_freq_receive          <dbl> 0.00, 0.21, 0.38, 0.31, 0.31, 0.96, 0.00, 0…
$ word_freq_will             <dbl> 0.64, 0.79, 0.45, 0.31, 0.31, 1.28, 0.00, 0…
$ word_freq_people           <dbl> 0.00, 0.65, 0.12, 0.31, 0.31, 0.00, 0.00, 0…
$ word_freq_report           <dbl> 0.00, 0.21, 0.00, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_addresses        <dbl> 0.00, 0.14, 1.75, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_free             <dbl> 0.32, 0.14, 0.06, 0.31, 0.31, 0.96, 0.00, 0…
$ word_freq_business         <dbl> 0.00, 0.07, 0.06, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_email            <dbl> 1.29, 0.28, 1.03, 0.00, 0.00, 0.32, 0.00, 0…
$ word_freq_you              <dbl> 1.93, 3.47, 1.36, 3.18, 3.18, 3.85, 0.00, 1…
$ word_freq_credit           <dbl> 0.00, 0.00, 0.32, 0.00, 0.00, 0.00, 0.00, 3…
$ word_freq_your             <dbl> 0.96, 1.59, 0.51, 0.31, 0.31, 0.64, 0.00, 2…
$ word_freq_font             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_000              <dbl> 0.00, 0.43, 1.16, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_money            <dbl> 0.00, 0.43, 0.06, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_hp               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_hpl              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_george           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_650              <dbl> 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_lab              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_labs             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_telnet           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_857              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_data             <dbl> 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_415              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_85               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_technology       <dbl> 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_1999             <dbl> 0.00, 0.07, 0.00, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_parts            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_pm               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_direct           <dbl> 0.00, 0.00, 0.06, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_cs               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_meeting          <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_original         <dbl> 0.00, 0.00, 0.12, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_project          <dbl> 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_re               <dbl> 0.00, 0.00, 0.06, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_edu              <dbl> 0.00, 0.00, 0.06, 0.00, 0.00, 0.00, 0.00, 0…
$ word_freq_table            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ word_freq_conference       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ `char_freq_;`              <dbl> 0.000, 0.000, 0.010, 0.000, 0.000, 0.000, 0…
$ `char_freq_(`              <dbl> 0.000, 0.132, 0.143, 0.137, 0.135, 0.054, 0…
$ `char_freq_[`              <dbl> 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0…
$ `char_freq_!`              <dbl> 0.778, 0.372, 0.276, 0.137, 0.135, 0.164, 0…
$ `char_freq_$`              <dbl> 0.000, 0.180, 0.184, 0.000, 0.000, 0.054, 0…
$ `char_freq_#`              <dbl> 0.000, 0.048, 0.010, 0.000, 0.000, 0.000, 0…
$ capital_run_length_average <dbl> 3.756, 5.114, 9.821, 3.537, 3.537, 1.671, 2…
$ capital_run_length_longest <dbl> 61, 101, 485, 40, 40, 4, 11, 445, 43, 24, 5…
$ capital_run_length_total   <dbl> 278, 1028, 2259, 191, 191, 112, 49, 1257, 7…
$ spam                       <ord> spam, spam, spam, spam, spam, spam, spam, s…

A lot of these seem to be word or character frequencies, so I suspect that they might be sparse (a lot of zero values) and have skewed distributions. Investigating:

train |> 
  select(-spam) |> 
  pivot_longer(cols = everything(), names_to = "Feature", values_to = "Value") |> 
  ggplot(aes(x = Value)) +
  geom_histogram(bins = 20, fill = "steelblue", color = "white") +
  facet_wrap(~ Feature, scales = "free") + 
  theme_minimal()

Preprocessing

Fitting the recipe to the training data (to avoid leakage). Synthetic minority class oversampling, to get the same amount of examples for both classes, log-transforming, normalizing, dropping highly correlated features & those with near-zero variance:

spam_rec <- 
  recipe(spam ~ ., data = train) |> 
  themis::step_smote(spam, over_ratio = 1, neighbors = 5) |> 
  step_log(all_numeric_predictors(), offset = 1) |> 
  step_range(all_numeric_predictors(), min = 0, max = 1) |> 
  step_corr(all_numeric_predictors(), threshold = 0.9) |> 
  step_nzv(all_numeric_predictors())

EDA

Class distribution:

train |> 
  count(spam) |> 
  ggplot(aes(x = spam, y = n, color = spam, fill = spam)) +
  geom_col(alpha = 0.7) +
  theme_minimal() +
  scale_color_brewer(palette = "Dark2", direction = -1) + 
  scale_fill_brewer(palette = "Dark2", direction = -1) +
  labs(
    title = "Class Distribution",
    x = "", 
    y = "N. of obs"
  ) +
  theme(legend.position = "none")

Feature correlations (unlabelled, hard to see with this many features anyways, just to check whether highly correlated features exist):

corrs <- 
  train |> 
  select(-spam) |> 
  cor() |> 
  as.data.frame() |> 
  rownames_to_column(var = "x1") |> 
  tibble() |> 
  pivot_longer(-x1, names_to = "x2", values_to = "val")

corrs |> 
  ggplot(aes(x = x1, y = x2, fill = val)) +
  geom_tile() +
  scale_fill_distiller(palette = "RdYlGn", direction = 1, limits = c(-1, 1)) +
  theme(axis.text = element_blank(), axis.ticks = element_blank()) +
  labs(
    title = "Pairwise correlations", 
    subtitle = "All features", 
    fill = "Pearson's r", 
    x = "",
    y = ""
  )

PCA: checking if we can find some patterns in a transformed representation (also seeing if PCA as part of the preprocessing pipeline might make sense)[^pca]:

spam_rec |> 
  step_pca(all_numeric_predictors(), num_comp = 4) |> 
  prep() |> 
  bake(new_data = train) |> 
  ggplot(aes(x = .panel_x, y = .panel_y, color = spam, fill = spam)) +
  geom_point(alpha = 0.4, size = 0.5) +
  ggforce::geom_autodensity(alpha = .3) +
  ggforce::facet_matrix(vars(-spam), layer.diag = 2) + 
  scale_color_brewer(palette = "Dark2", direction = -1) + 
  scale_fill_brewer(palette = "Dark2", direction = -1) +
  theme_minimal() +
  labs(title = "Principal Component Analysis", fill = "", color = "")

[^pca] I tried this (it was tempting to just magically generate richt & uncorrelated features instead of dealing with the existing ones, but that would make the neural network more unstable)

Neural Network Classifier

Preparing the data (separate features & labels, bring into matrix format). We also need to apply the fitted preprocessing pipeline here for the data going into keras. prep() fits the recipe (on the training data, we specified this in the recipe), and bake() applies the transformation (equivalent to .fit() and .transform() in sklearn pipelines):

keras_split <- function(set) {
  df <- 
    set |> 
    mutate(spam = if_else(spam == "spam", 1, 0))
  
  list(
    X = df |> select(-spam) |> as.matrix() |> unname(),
    y = df |> pull(spam) |> as.matrix()
  )
}

keras_train <- spam_rec |> prep() |> bake(new_data = train) |> keras_split()
keras_val <-  spam_rec |> prep() |> bake(new_data = val) |> keras_split()
keras_test <- spam_rec |> prep() |> bake(new_data = test) |> keras_split()

X_train <- keras_train$X
y_train <- keras_train$y
X_val <- keras_val$X
y_val <- keras_val$y
X_test <- keras_test$X
y_test <- keras_test$y

Model:

keras3::set_random_seed(42) 
#^ the keras3 version (supposedly) sets a seed for the R session and the whole backend

mlp <- keras_model_sequential(
  layers = list(
    layer_dense(units = 128, activation = "relu", kernel_regularizer = regularizer_l2(0.001)),
    layer_dropout(rate = 0.25),
    layer_dense(units = 64, activation = "relu", kernel_regularizer = regularizer_l2(0.001)),
    layer_dropout(rate = 0.25),
    layer_dense(units = 32, activation = "relu", kernel_regularizer = regularizer_l2(0.001)),
    layer_dropout(rate = 0.25),
    layer_dense(units = 16, activation = "relu", kernel_regularizer = regularizer_l2(0.001)),
    layer_dropout(rate = 0.25),
    layer_dense(units = 1, activation = "sigmoid")
  )
)

Compiling:

keras3::set_random_seed(42)
# chunks are independent when rendering apparently, so again...

mlp |> 
  compile(
    optimizer = optimizer_adam(learning_rate = 0.001),
    loss = "binary_crossentropy",
    metrics = list(
      metric_binary_accuracy(),
      metric_precision(),
      metric_recall(),
      metric_auc()
    )
  )

Training:

keras3::set_random_seed(42)

history <- 
  mlp |> 
  fit(
    x = X_train,
    y = y_train,
    epochs = 250L,
    batch_size = 32L,
    validation_data = list(X_val, y_val),
    callbacks = list(
      # early stopping:
      callback_early_stopping(
        monitor = "val_loss",
        patience = 5L,
        restore_best_weights = TRUE 
      ),
      # schedule learning rate:
      callback_reduce_lr_on_plateau(
        monitor = "val_loss",
        factor = 0.8,
        patience = 3L,
        min_lr = 0.00001
      )
    ),
    shuffle = FALSE
  )
Epoch 1/250
87/87 - 6s - 73ms/step - auc: 0.6983 - binary_accuracy: 0.5640 - loss: 0.9687 - precision: 0.4744 - recall: 0.9871 - val_auc: 0.2853 - val_binary_accuracy: 0.6059 - val_loss: 0.8075 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - learning_rate: 0.0010
Epoch 2/250
87/87 - 0s - 3ms/step - auc: 0.3105 - binary_accuracy: 0.5143 - loss: 0.8138 - precision: 0.3364 - recall: 0.2392 - val_auc: 0.8465 - val_binary_accuracy: 0.6243 - val_loss: 0.7709 - val_precision: 1.0000 - val_recall: 0.0468 - learning_rate: 0.0010
Epoch 3/250
87/87 - 0s - 3ms/step - auc: 0.7258 - binary_accuracy: 0.6531 - loss: 0.7476 - precision: 0.5518 - recall: 0.6375 - val_auc: 0.6754 - val_binary_accuracy: 0.6059 - val_loss: 0.7395 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - learning_rate: 0.0010
Epoch 4/250
87/87 - 0s - 3ms/step - auc: 0.7038 - binary_accuracy: 0.6876 - loss: 0.7083 - precision: 0.6048 - recall: 0.5971 - val_auc: 0.8798 - val_binary_accuracy: 0.6384 - val_loss: 0.6782 - val_precision: 0.9688 - val_recall: 0.0854 - learning_rate: 0.0010
Epoch 5/250
87/87 - 0s - 2ms/step - auc: 0.8246 - binary_accuracy: 0.7746 - loss: 0.5927 - precision: 0.7313 - recall: 0.6762 - val_auc: 0.8903 - val_binary_accuracy: 0.6384 - val_loss: 0.8259 - val_precision: 0.9688 - val_recall: 0.0854 - learning_rate: 0.0010
Epoch 6/250
87/87 - 0s - 2ms/step - auc: 0.8152 - binary_accuracy: 0.7952 - loss: 0.5887 - precision: 0.7490 - recall: 0.7222 - val_auc: 0.9451 - val_binary_accuracy: 0.8187 - val_loss: 0.5009 - val_precision: 0.9537 - val_recall: 0.5675 - learning_rate: 0.0010
Epoch 7/250
87/87 - 0s - 3ms/step - auc: 0.9169 - binary_accuracy: 0.8590 - loss: 0.4420 - precision: 0.8428 - recall: 0.7893 - val_auc: 0.9476 - val_binary_accuracy: 0.8208 - val_loss: 0.5265 - val_precision: 0.9541 - val_recall: 0.5730 - learning_rate: 0.0010
Epoch 8/250
87/87 - 0s - 3ms/step - auc: 0.9176 - binary_accuracy: 0.8731 - loss: 0.4301 - precision: 0.8428 - recall: 0.8335 - val_auc: 0.9543 - val_binary_accuracy: 0.8469 - val_loss: 0.4732 - val_precision: 0.9440 - val_recall: 0.6501 - learning_rate: 0.0010
Epoch 9/250
87/87 - 0s - 3ms/step - auc: 0.9288 - binary_accuracy: 0.8847 - loss: 0.4070 - precision: 0.8583 - recall: 0.8473 - val_auc: 0.9569 - val_binary_accuracy: 0.8719 - val_loss: 0.4453 - val_precision: 0.9487 - val_recall: 0.7135 - learning_rate: 0.0010
Epoch 10/250
87/87 - 0s - 2ms/step - auc: 0.9405 - binary_accuracy: 0.8942 - loss: 0.3826 - precision: 0.8664 - recall: 0.8648 - val_auc: 0.9586 - val_binary_accuracy: 0.8762 - val_loss: 0.4344 - val_precision: 0.9495 - val_recall: 0.7245 - learning_rate: 0.0010
Epoch 11/250
87/87 - 0s - 4ms/step - auc: 0.9368 - binary_accuracy: 0.8949 - loss: 0.3866 - precision: 0.8666 - recall: 0.8666 - val_auc: 0.9601 - val_binary_accuracy: 0.8806 - val_loss: 0.4214 - val_precision: 0.9470 - val_recall: 0.7383 - learning_rate: 0.0010
Epoch 12/250
87/87 - 0s - 4ms/step - auc: 0.9441 - binary_accuracy: 0.8956 - loss: 0.3693 - precision: 0.8682 - recall: 0.8666 - val_auc: 0.9612 - val_binary_accuracy: 0.8871 - val_loss: 0.4229 - val_precision: 0.9481 - val_recall: 0.7548 - learning_rate: 0.0010
Epoch 13/250
87/87 - 0s - 2ms/step - auc: 0.9453 - binary_accuracy: 0.9007 - loss: 0.3653 - precision: 0.8726 - recall: 0.8758 - val_auc: 0.9627 - val_binary_accuracy: 0.8893 - val_loss: 0.4111 - val_precision: 0.9485 - val_recall: 0.7603 - learning_rate: 0.0010
Epoch 14/250
87/87 - 0s - 3ms/step - auc: 0.9448 - binary_accuracy: 0.9007 - loss: 0.3633 - precision: 0.8712 - recall: 0.8776 - val_auc: 0.9631 - val_binary_accuracy: 0.8893 - val_loss: 0.4048 - val_precision: 0.9454 - val_recall: 0.7631 - learning_rate: 0.0010
Epoch 15/250
87/87 - 0s - 3ms/step - auc: 0.9465 - binary_accuracy: 0.9069 - loss: 0.3575 - precision: 0.8843 - recall: 0.8786 - val_auc: 0.9641 - val_binary_accuracy: 0.8849 - val_loss: 0.4057 - val_precision: 0.9446 - val_recall: 0.7521 - learning_rate: 0.0010
Epoch 16/250
87/87 - 0s - 4ms/step - auc: 0.9477 - binary_accuracy: 0.9014 - loss: 0.3563 - precision: 0.8798 - recall: 0.8684 - val_auc: 0.9651 - val_binary_accuracy: 0.8871 - val_loss: 0.3966 - val_precision: 0.9450 - val_recall: 0.7576 - learning_rate: 0.0010
Epoch 17/250
87/87 - 0s - 3ms/step - auc: 0.9515 - binary_accuracy: 0.9032 - loss: 0.3451 - precision: 0.8789 - recall: 0.8749 - val_auc: 0.9658 - val_binary_accuracy: 0.8882 - val_loss: 0.3924 - val_precision: 0.9452 - val_recall: 0.7603 - learning_rate: 0.0010
Epoch 18/250
87/87 - 0s - 2ms/step - auc: 0.9560 - binary_accuracy: 0.9087 - loss: 0.3285 - precision: 0.8848 - recall: 0.8832 - val_auc: 0.9653 - val_binary_accuracy: 0.8925 - val_loss: 0.3971 - val_precision: 0.9430 - val_recall: 0.7741 - learning_rate: 0.0010
Epoch 19/250
87/87 - 0s - 2ms/step - auc: 0.9521 - binary_accuracy: 0.9108 - loss: 0.3363 - precision: 0.8876 - recall: 0.8859 - val_auc: 0.9660 - val_binary_accuracy: 0.8893 - val_loss: 0.3936 - val_precision: 0.9454 - val_recall: 0.7631 - learning_rate: 0.0010
Epoch 20/250
87/87 - 0s - 2ms/step - auc: 0.9515 - binary_accuracy: 0.9130 - loss: 0.3328 - precision: 0.8875 - recall: 0.8924 - val_auc: 0.9653 - val_binary_accuracy: 0.8817 - val_loss: 0.4040 - val_precision: 0.9472 - val_recall: 0.7410 - learning_rate: 0.0010
Epoch 21/250
87/87 - 0s - 3ms/step - auc: 0.9467 - binary_accuracy: 0.9047 - loss: 0.3502 - precision: 0.8836 - recall: 0.8730 - val_auc: 0.9684 - val_binary_accuracy: 0.9088 - val_loss: 0.3405 - val_precision: 0.9401 - val_recall: 0.8209 - learning_rate: 8.0000e-04
Epoch 22/250
87/87 - 0s - 3ms/step - auc: 0.9536 - binary_accuracy: 0.9163 - loss: 0.3249 - precision: 0.8970 - recall: 0.8896 - val_auc: 0.9679 - val_binary_accuracy: 0.9077 - val_loss: 0.3401 - val_precision: 0.9399 - val_recall: 0.8182 - learning_rate: 8.0000e-04
Epoch 23/250
87/87 - 0s - 3ms/step - auc: 0.9530 - binary_accuracy: 0.9134 - loss: 0.3303 - precision: 0.8911 - recall: 0.8887 - val_auc: 0.9680 - val_binary_accuracy: 0.9034 - val_loss: 0.3461 - val_precision: 0.9419 - val_recall: 0.8044 - learning_rate: 8.0000e-04
Epoch 24/250
87/87 - 0s - 5ms/step - auc: 0.9528 - binary_accuracy: 0.9145 - loss: 0.3277 - precision: 0.8922 - recall: 0.8905 - val_auc: 0.9682 - val_binary_accuracy: 0.8990 - val_loss: 0.3511 - val_precision: 0.9412 - val_recall: 0.7934 - learning_rate: 8.0000e-04
Epoch 25/250
87/87 - 0s - 3ms/step - auc: 0.9508 - binary_accuracy: 0.9083 - loss: 0.3307 - precision: 0.8875 - recall: 0.8786 - val_auc: 0.9684 - val_binary_accuracy: 0.8979 - val_loss: 0.3520 - val_precision: 0.9410 - val_recall: 0.7906 - learning_rate: 8.0000e-04
Epoch 26/250
87/87 - 0s - 3ms/step - auc: 0.9504 - binary_accuracy: 0.9065 - loss: 0.3384 - precision: 0.8877 - recall: 0.8730 - val_auc: 0.9698 - val_binary_accuracy: 0.9110 - val_loss: 0.3156 - val_precision: 0.9404 - val_recall: 0.8264 - learning_rate: 6.4000e-04
Epoch 27/250
87/87 - 0s - 2ms/step - auc: 0.9557 - binary_accuracy: 0.9184 - loss: 0.3143 - precision: 0.8998 - recall: 0.8924 - val_auc: 0.9699 - val_binary_accuracy: 0.9110 - val_loss: 0.3128 - val_precision: 0.9404 - val_recall: 0.8264 - learning_rate: 6.4000e-04
Epoch 28/250
87/87 - 0s - 3ms/step - auc: 0.9567 - binary_accuracy: 0.9141 - loss: 0.3151 - precision: 0.8892 - recall: 0.8933 - val_auc: 0.9701 - val_binary_accuracy: 0.9110 - val_loss: 0.3167 - val_precision: 0.9404 - val_recall: 0.8264 - learning_rate: 6.4000e-04
Epoch 29/250
87/87 - 0s - 3ms/step - auc: 0.9556 - binary_accuracy: 0.9181 - loss: 0.3152 - precision: 0.8953 - recall: 0.8970 - val_auc: 0.9699 - val_binary_accuracy: 0.9066 - val_loss: 0.3187 - val_precision: 0.9397 - val_recall: 0.8154 - learning_rate: 6.4000e-04
Epoch 30/250
87/87 - 0s - 2ms/step - auc: 0.9560 - binary_accuracy: 0.9155 - loss: 0.3143 - precision: 0.8954 - recall: 0.8896 - val_auc: 0.9697 - val_binary_accuracy: 0.9045 - val_loss: 0.3198 - val_precision: 0.9393 - val_recall: 0.8099 - learning_rate: 6.4000e-04
Epoch 31/250
87/87 - 0s - 3ms/step - auc: 0.9580 - binary_accuracy: 0.9130 - loss: 0.3078 - precision: 0.8925 - recall: 0.8859 - val_auc: 0.9704 - val_binary_accuracy: 0.9153 - val_loss: 0.3051 - val_precision: 0.9412 - val_recall: 0.8375 - learning_rate: 5.1200e-04
Epoch 32/250
87/87 - 0s - 4ms/step - auc: 0.9572 - binary_accuracy: 0.9137 - loss: 0.3040 - precision: 0.8934 - recall: 0.8868 - val_auc: 0.9702 - val_binary_accuracy: 0.9142 - val_loss: 0.3027 - val_precision: 0.9410 - val_recall: 0.8347 - learning_rate: 5.1200e-04
Epoch 33/250
87/87 - 0s - 3ms/step - auc: 0.9581 - binary_accuracy: 0.9203 - loss: 0.3037 - precision: 0.8981 - recall: 0.8997 - val_auc: 0.9704 - val_binary_accuracy: 0.9121 - val_loss: 0.3027 - val_precision: 0.9406 - val_recall: 0.8292 - learning_rate: 5.1200e-04
Epoch 34/250
87/87 - 0s - 3ms/step - auc: 0.9577 - binary_accuracy: 0.9174 - loss: 0.3076 - precision: 0.8981 - recall: 0.8914 - val_auc: 0.9702 - val_binary_accuracy: 0.9099 - val_loss: 0.3059 - val_precision: 0.9403 - val_recall: 0.8237 - learning_rate: 5.1200e-04
Epoch 35/250
87/87 - 0s - 2ms/step - auc: 0.9577 - binary_accuracy: 0.9152 - loss: 0.3018 - precision: 0.8916 - recall: 0.8933 - val_auc: 0.9702 - val_binary_accuracy: 0.9110 - val_loss: 0.3028 - val_precision: 0.9404 - val_recall: 0.8264 - learning_rate: 5.1200e-04
Epoch 36/250
87/87 - 0s - 4ms/step - auc: 0.9554 - binary_accuracy: 0.9152 - loss: 0.3067 - precision: 0.8960 - recall: 0.8878 - val_auc: 0.9709 - val_binary_accuracy: 0.9121 - val_loss: 0.2926 - val_precision: 0.9352 - val_recall: 0.8347 - learning_rate: 4.0960e-04
Epoch 37/250
87/87 - 0s - 3ms/step - auc: 0.9578 - binary_accuracy: 0.9213 - loss: 0.2975 - precision: 0.9050 - recall: 0.8942 - val_auc: 0.9707 - val_binary_accuracy: 0.9131 - val_loss: 0.2923 - val_precision: 0.9381 - val_recall: 0.8347 - learning_rate: 4.0960e-04
Epoch 38/250
87/87 - 0s - 3ms/step - auc: 0.9569 - binary_accuracy: 0.9206 - loss: 0.3011 - precision: 0.9033 - recall: 0.8942 - val_auc: 0.9707 - val_binary_accuracy: 0.9121 - val_loss: 0.2894 - val_precision: 0.9352 - val_recall: 0.8347 - learning_rate: 4.0960e-04
Epoch 39/250
87/87 - 0s - 2ms/step - auc: 0.9637 - binary_accuracy: 0.9242 - loss: 0.2868 - precision: 0.9095 - recall: 0.8970 - val_auc: 0.9704 - val_binary_accuracy: 0.9110 - val_loss: 0.2952 - val_precision: 0.9377 - val_recall: 0.8292 - learning_rate: 4.0960e-04
Epoch 40/250
87/87 - 0s - 3ms/step - auc: 0.9599 - binary_accuracy: 0.9213 - loss: 0.2946 - precision: 0.9035 - recall: 0.8960 - val_auc: 0.9704 - val_binary_accuracy: 0.9121 - val_loss: 0.2947 - val_precision: 0.9406 - val_recall: 0.8292 - learning_rate: 4.0960e-04
Epoch 41/250
87/87 - 0s - 6ms/step - auc: 0.9595 - binary_accuracy: 0.9232 - loss: 0.2952 - precision: 0.9070 - recall: 0.8970 - val_auc: 0.9712 - val_binary_accuracy: 0.9121 - val_loss: 0.2944 - val_precision: 0.9406 - val_recall: 0.8292 - learning_rate: 4.0960e-04
Epoch 42/250
87/87 - 0s - 2ms/step - auc: 0.9594 - binary_accuracy: 0.9184 - loss: 0.2921 - precision: 0.9058 - recall: 0.8850 - val_auc: 0.9712 - val_binary_accuracy: 0.9142 - val_loss: 0.2852 - val_precision: 0.9356 - val_recall: 0.8402 - learning_rate: 3.2768e-04
Epoch 43/250
87/87 - 0s - 3ms/step - auc: 0.9622 - binary_accuracy: 0.9261 - loss: 0.2848 - precision: 0.9122 - recall: 0.8988 - val_auc: 0.9711 - val_binary_accuracy: 0.9142 - val_loss: 0.2850 - val_precision: 0.9356 - val_recall: 0.8402 - learning_rate: 3.2768e-04
Epoch 44/250
87/87 - 0s - 2ms/step - auc: 0.9602 - binary_accuracy: 0.9253 - loss: 0.2879 - precision: 0.9082 - recall: 0.9016 - val_auc: 0.9716 - val_binary_accuracy: 0.9164 - val_loss: 0.2811 - val_precision: 0.9360 - val_recall: 0.8457 - learning_rate: 3.2768e-04
Epoch 45/250
87/87 - 0s - 4ms/step - auc: 0.9607 - binary_accuracy: 0.9217 - loss: 0.2883 - precision: 0.9021 - recall: 0.8988 - val_auc: 0.9716 - val_binary_accuracy: 0.9175 - val_loss: 0.2809 - val_precision: 0.9362 - val_recall: 0.8485 - learning_rate: 3.2768e-04
Epoch 46/250
87/87 - 0s - 3ms/step - auc: 0.9634 - binary_accuracy: 0.9261 - loss: 0.2809 - precision: 0.9099 - recall: 0.9016 - val_auc: 0.9716 - val_binary_accuracy: 0.9175 - val_loss: 0.2837 - val_precision: 0.9388 - val_recall: 0.8457 - learning_rate: 3.2768e-04
Epoch 47/250
87/87 - 0s - 3ms/step - auc: 0.9624 - binary_accuracy: 0.9250 - loss: 0.2829 - precision: 0.9082 - recall: 0.9006 - val_auc: 0.9711 - val_binary_accuracy: 0.9131 - val_loss: 0.2874 - val_precision: 0.9381 - val_recall: 0.8347 - learning_rate: 3.2768e-04
Epoch 48/250
87/87 - 0s - 3ms/step - auc: 0.9635 - binary_accuracy: 0.9224 - loss: 0.2800 - precision: 0.9053 - recall: 0.8970 - val_auc: 0.9711 - val_binary_accuracy: 0.9142 - val_loss: 0.2862 - val_precision: 0.9383 - val_recall: 0.8375 - learning_rate: 3.2768e-04
Epoch 49/250
87/87 - 0s - 3ms/step - auc: 0.9628 - binary_accuracy: 0.9239 - loss: 0.2793 - precision: 0.9094 - recall: 0.8960 - val_auc: 0.9720 - val_binary_accuracy: 0.9197 - val_loss: 0.2755 - val_precision: 0.9366 - val_recall: 0.8540 - learning_rate: 2.6214e-04
Epoch 50/250
87/87 - 0s - 4ms/step - auc: 0.9619 - binary_accuracy: 0.9286 - loss: 0.2806 - precision: 0.9105 - recall: 0.9080 - val_auc: 0.9719 - val_binary_accuracy: 0.9186 - val_loss: 0.2772 - val_precision: 0.9364 - val_recall: 0.8512 - learning_rate: 2.6214e-04
Epoch 51/250
87/87 - 0s - 3ms/step - auc: 0.9659 - binary_accuracy: 0.9271 - loss: 0.2737 - precision: 0.9117 - recall: 0.9025 - val_auc: 0.9719 - val_binary_accuracy: 0.9175 - val_loss: 0.2773 - val_precision: 0.9362 - val_recall: 0.8485 - learning_rate: 2.6214e-04
Epoch 52/250
87/87 - 0s - 3ms/step - auc: 0.9622 - binary_accuracy: 0.9253 - loss: 0.2809 - precision: 0.9098 - recall: 0.8997 - val_auc: 0.9715 - val_binary_accuracy: 0.9175 - val_loss: 0.2770 - val_precision: 0.9362 - val_recall: 0.8485 - learning_rate: 2.6214e-04
Epoch 53/250
87/87 - 0s - 3ms/step - auc: 0.9618 - binary_accuracy: 0.9290 - loss: 0.2768 - precision: 0.9191 - recall: 0.8988 - val_auc: 0.9719 - val_binary_accuracy: 0.9186 - val_loss: 0.2726 - val_precision: 0.9364 - val_recall: 0.8512 - learning_rate: 2.0972e-04
Epoch 54/250
87/87 - 0s - 4ms/step - auc: 0.9622 - binary_accuracy: 0.9297 - loss: 0.2805 - precision: 0.9138 - recall: 0.9071 - val_auc: 0.9716 - val_binary_accuracy: 0.9186 - val_loss: 0.2733 - val_precision: 0.9364 - val_recall: 0.8512 - learning_rate: 2.0972e-04
Epoch 55/250
87/87 - 0s - 2ms/step - auc: 0.9669 - binary_accuracy: 0.9279 - loss: 0.2693 - precision: 0.9142 - recall: 0.9016 - val_auc: 0.9717 - val_binary_accuracy: 0.9197 - val_loss: 0.2723 - val_precision: 0.9366 - val_recall: 0.8540 - learning_rate: 2.0972e-04
Epoch 56/250
87/87 - 0s - 3ms/step - auc: 0.9617 - binary_accuracy: 0.9257 - loss: 0.2782 - precision: 0.9121 - recall: 0.8979 - val_auc: 0.9715 - val_binary_accuracy: 0.9207 - val_loss: 0.2711 - val_precision: 0.9367 - val_recall: 0.8567 - learning_rate: 2.0972e-04
Epoch 57/250
87/87 - 0s - 2ms/step - auc: 0.9620 - binary_accuracy: 0.9286 - loss: 0.2758 - precision: 0.9113 - recall: 0.9071 - val_auc: 0.9717 - val_binary_accuracy: 0.9197 - val_loss: 0.2715 - val_precision: 0.9366 - val_recall: 0.8540 - learning_rate: 2.0972e-04
Epoch 58/250
87/87 - 0s - 3ms/step - auc: 0.9643 - binary_accuracy: 0.9300 - loss: 0.2739 - precision: 0.9162 - recall: 0.9052 - val_auc: 0.9718 - val_binary_accuracy: 0.9229 - val_loss: 0.2709 - val_precision: 0.9371 - val_recall: 0.8623 - learning_rate: 2.0972e-04
Epoch 59/250
87/87 - 0s - 4ms/step - auc: 0.9644 - binary_accuracy: 0.9297 - loss: 0.2721 - precision: 0.9185 - recall: 0.9016 - val_auc: 0.9718 - val_binary_accuracy: 0.9207 - val_loss: 0.2717 - val_precision: 0.9367 - val_recall: 0.8567 - learning_rate: 2.0972e-04
Epoch 60/250
87/87 - 0s - 3ms/step - auc: 0.9669 - binary_accuracy: 0.9322 - loss: 0.2616 - precision: 0.9151 - recall: 0.9126 - val_auc: 0.9713 - val_binary_accuracy: 0.9197 - val_loss: 0.2722 - val_precision: 0.9366 - val_recall: 0.8540 - learning_rate: 2.0972e-04
Epoch 61/250
87/87 - 0s - 2ms/step - auc: 0.9645 - binary_accuracy: 0.9261 - loss: 0.2713 - precision: 0.9107 - recall: 0.9006 - val_auc: 0.9713 - val_binary_accuracy: 0.9207 - val_loss: 0.2720 - val_precision: 0.9367 - val_recall: 0.8567 - learning_rate: 2.0972e-04
Epoch 62/250
87/87 - 0s - 4ms/step - auc: 0.9646 - binary_accuracy: 0.9297 - loss: 0.2703 - precision: 0.9177 - recall: 0.9025 - val_auc: 0.9717 - val_binary_accuracy: 0.9240 - val_loss: 0.2700 - val_precision: 0.9373 - val_recall: 0.8650 - learning_rate: 1.6777e-04
Epoch 63/250
87/87 - 0s - 3ms/step - auc: 0.9651 - binary_accuracy: 0.9308 - loss: 0.2676 - precision: 0.9171 - recall: 0.9062 - val_auc: 0.9720 - val_binary_accuracy: 0.9240 - val_loss: 0.2697 - val_precision: 0.9373 - val_recall: 0.8650 - learning_rate: 1.6777e-04
Epoch 64/250
87/87 - 0s - 2ms/step - auc: 0.9647 - binary_accuracy: 0.9293 - loss: 0.2707 - precision: 0.9145 - recall: 0.9052 - val_auc: 0.9718 - val_binary_accuracy: 0.9240 - val_loss: 0.2682 - val_precision: 0.9373 - val_recall: 0.8650 - learning_rate: 1.6777e-04
Epoch 65/250
87/87 - 0s - 3ms/step - auc: 0.9626 - binary_accuracy: 0.9293 - loss: 0.2754 - precision: 0.9145 - recall: 0.9052 - val_auc: 0.9718 - val_binary_accuracy: 0.9240 - val_loss: 0.2675 - val_precision: 0.9373 - val_recall: 0.8650 - learning_rate: 1.6777e-04
Epoch 66/250
87/87 - 0s - 3ms/step - auc: 0.9670 - binary_accuracy: 0.9293 - loss: 0.2634 - precision: 0.9137 - recall: 0.9062 - val_auc: 0.9718 - val_binary_accuracy: 0.9251 - val_loss: 0.2666 - val_precision: 0.9375 - val_recall: 0.8678 - learning_rate: 1.6777e-04
Epoch 67/250
87/87 - 0s - 3ms/step - auc: 0.9640 - binary_accuracy: 0.9304 - loss: 0.2714 - precision: 0.9147 - recall: 0.9080 - val_auc: 0.9719 - val_binary_accuracy: 0.9251 - val_loss: 0.2669 - val_precision: 0.9375 - val_recall: 0.8678 - learning_rate: 1.6777e-04
Epoch 68/250
87/87 - 0s - 2ms/step - auc: 0.9649 - binary_accuracy: 0.9311 - loss: 0.2673 - precision: 0.9196 - recall: 0.9043 - val_auc: 0.9718 - val_binary_accuracy: 0.9240 - val_loss: 0.2675 - val_precision: 0.9373 - val_recall: 0.8650 - learning_rate: 1.6777e-04
Epoch 69/250
87/87 - 0s - 4ms/step - auc: 0.9658 - binary_accuracy: 0.9315 - loss: 0.2656 - precision: 0.9196 - recall: 0.9052 - val_auc: 0.9718 - val_binary_accuracy: 0.9240 - val_loss: 0.2677 - val_precision: 0.9373 - val_recall: 0.8650 - learning_rate: 1.6777e-04
Epoch 70/250
87/87 - 0s - 4ms/step - auc: 0.9668 - binary_accuracy: 0.9326 - loss: 0.2647 - precision: 0.9246 - recall: 0.9025 - val_auc: 0.9718 - val_binary_accuracy: 0.9251 - val_loss: 0.2665 - val_precision: 0.9375 - val_recall: 0.8678 - learning_rate: 1.3422e-04
Epoch 71/250
87/87 - 0s - 2ms/step - auc: 0.9650 - binary_accuracy: 0.9286 - loss: 0.2701 - precision: 0.9167 - recall: 0.9006 - val_auc: 0.9718 - val_binary_accuracy: 0.9262 - val_loss: 0.2653 - val_precision: 0.9377 - val_recall: 0.8705 - learning_rate: 1.3422e-04
Epoch 72/250
87/87 - 0s - 4ms/step - auc: 0.9670 - binary_accuracy: 0.9275 - loss: 0.2623 - precision: 0.9149 - recall: 0.8997 - val_auc: 0.9717 - val_binary_accuracy: 0.9262 - val_loss: 0.2652 - val_precision: 0.9377 - val_recall: 0.8705 - learning_rate: 1.3422e-04
Epoch 73/250
87/87 - 0s - 3ms/step - auc: 0.9689 - binary_accuracy: 0.9297 - loss: 0.2586 - precision: 0.9161 - recall: 0.9043 - val_auc: 0.9720 - val_binary_accuracy: 0.9262 - val_loss: 0.2659 - val_precision: 0.9377 - val_recall: 0.8705 - learning_rate: 1.3422e-04
Epoch 74/250
87/87 - 0s - 2ms/step - auc: 0.9647 - binary_accuracy: 0.9337 - loss: 0.2656 - precision: 0.9209 - recall: 0.9098 - val_auc: 0.9720 - val_binary_accuracy: 0.9262 - val_loss: 0.2659 - val_precision: 0.9377 - val_recall: 0.8705 - learning_rate: 1.3422e-04
Epoch 75/250
87/87 - 0s - 2ms/step - auc: 0.9660 - binary_accuracy: 0.9333 - loss: 0.2622 - precision: 0.9224 - recall: 0.9071 - val_auc: 0.9720 - val_binary_accuracy: 0.9262 - val_loss: 0.2650 - val_precision: 0.9377 - val_recall: 0.8705 - learning_rate: 1.3422e-04
Epoch 76/250
87/87 - 0s - 2ms/step - auc: 0.9681 - binary_accuracy: 0.9344 - loss: 0.2581 - precision: 0.9234 - recall: 0.9089 - val_auc: 0.9720 - val_binary_accuracy: 0.9262 - val_loss: 0.2646 - val_precision: 0.9377 - val_recall: 0.8705 - learning_rate: 1.3422e-04
Epoch 77/250
87/87 - 0s - 2ms/step - auc: 0.9653 - binary_accuracy: 0.9308 - loss: 0.2654 - precision: 0.9187 - recall: 0.9043 - val_auc: 0.9719 - val_binary_accuracy: 0.9273 - val_loss: 0.2631 - val_precision: 0.9379 - val_recall: 0.8733 - learning_rate: 1.3422e-04
Epoch 78/250
87/87 - 0s - 2ms/step - auc: 0.9656 - binary_accuracy: 0.9304 - loss: 0.2632 - precision: 0.9140 - recall: 0.9089 - val_auc: 0.9720 - val_binary_accuracy: 0.9262 - val_loss: 0.2638 - val_precision: 0.9377 - val_recall: 0.8705 - learning_rate: 1.3422e-04
Epoch 79/250
87/87 - 0s - 3ms/step - auc: 0.9650 - binary_accuracy: 0.9308 - loss: 0.2668 - precision: 0.9164 - recall: 0.9071 - val_auc: 0.9720 - val_binary_accuracy: 0.9240 - val_loss: 0.2648 - val_precision: 0.9373 - val_recall: 0.8650 - learning_rate: 1.3422e-04
Epoch 80/250
87/87 - 0s - 3ms/step - auc: 0.9655 - binary_accuracy: 0.9322 - loss: 0.2639 - precision: 0.9198 - recall: 0.9071 - val_auc: 0.9719 - val_binary_accuracy: 0.9240 - val_loss: 0.2645 - val_precision: 0.9373 - val_recall: 0.8650 - learning_rate: 1.3422e-04
Epoch 81/250
87/87 - 0s - 4ms/step - auc: 0.9657 - binary_accuracy: 0.9322 - loss: 0.2631 - precision: 0.9213 - recall: 0.9052 - val_auc: 0.9720 - val_binary_accuracy: 0.9262 - val_loss: 0.2616 - val_precision: 0.9351 - val_recall: 0.8733 - learning_rate: 1.0737e-04
Epoch 82/250
87/87 - 0s - 3ms/step - auc: 0.9682 - binary_accuracy: 0.9315 - loss: 0.2572 - precision: 0.9204 - recall: 0.9043 - val_auc: 0.9721 - val_binary_accuracy: 0.9273 - val_loss: 0.2621 - val_precision: 0.9379 - val_recall: 0.8733 - learning_rate: 1.0737e-04
Epoch 83/250
87/87 - 0s - 4ms/step - auc: 0.9676 - binary_accuracy: 0.9304 - loss: 0.2602 - precision: 0.9178 - recall: 0.9043 - val_auc: 0.9722 - val_binary_accuracy: 0.9262 - val_loss: 0.2619 - val_precision: 0.9377 - val_recall: 0.8705 - learning_rate: 1.0737e-04
Epoch 84/250
87/87 - 0s - 2ms/step - auc: 0.9676 - binary_accuracy: 0.9319 - loss: 0.2598 - precision: 0.9197 - recall: 0.9062 - val_auc: 0.9722 - val_binary_accuracy: 0.9251 - val_loss: 0.2636 - val_precision: 0.9375 - val_recall: 0.8678 - learning_rate: 1.0737e-04
Epoch 85/250
87/87 - 0s - 5ms/step - auc: 0.9697 - binary_accuracy: 0.9322 - loss: 0.2496 - precision: 0.9221 - recall: 0.9043 - val_auc: 0.9723 - val_binary_accuracy: 0.9262 - val_loss: 0.2617 - val_precision: 0.9377 - val_recall: 0.8705 - learning_rate: 8.5899e-05
Epoch 86/250
87/87 - 0s - 2ms/step - auc: 0.9672 - binary_accuracy: 0.9319 - loss: 0.2602 - precision: 0.9189 - recall: 0.9071 - val_auc: 0.9721 - val_binary_accuracy: 0.9251 - val_loss: 0.2614 - val_precision: 0.9349 - val_recall: 0.8705 - learning_rate: 8.5899e-05
Epoch 87/250
87/87 - 0s - 2ms/step - auc: 0.9691 - binary_accuracy: 0.9348 - loss: 0.2549 - precision: 0.9203 - recall: 0.9135 - val_auc: 0.9722 - val_binary_accuracy: 0.9251 - val_loss: 0.2614 - val_precision: 0.9349 - val_recall: 0.8705 - learning_rate: 8.5899e-05
Epoch 88/250
87/87 - 0s - 4ms/step - auc: 0.9680 - binary_accuracy: 0.9344 - loss: 0.2572 - precision: 0.9226 - recall: 0.9098 - val_auc: 0.9722 - val_binary_accuracy: 0.9251 - val_loss: 0.2608 - val_precision: 0.9349 - val_recall: 0.8705 - learning_rate: 8.5899e-05
Epoch 89/250
87/87 - 0s - 3ms/step - auc: 0.9681 - binary_accuracy: 0.9344 - loss: 0.2566 - precision: 0.9210 - recall: 0.9117 - val_auc: 0.9722 - val_binary_accuracy: 0.9251 - val_loss: 0.2614 - val_precision: 0.9349 - val_recall: 0.8705 - learning_rate: 8.5899e-05
Epoch 90/250
87/87 - 0s - 3ms/step - auc: 0.9656 - binary_accuracy: 0.9333 - loss: 0.2623 - precision: 0.9224 - recall: 0.9071 - val_auc: 0.9720 - val_binary_accuracy: 0.9262 - val_loss: 0.2597 - val_precision: 0.9351 - val_recall: 0.8733 - learning_rate: 8.5899e-05
Epoch 91/250
87/87 - 0s - 4ms/step - auc: 0.9693 - binary_accuracy: 0.9333 - loss: 0.2538 - precision: 0.9255 - recall: 0.9034 - val_auc: 0.9722 - val_binary_accuracy: 0.9251 - val_loss: 0.2608 - val_precision: 0.9349 - val_recall: 0.8705 - learning_rate: 8.5899e-05
Epoch 92/250
87/87 - 0s - 2ms/step - auc: 0.9660 - binary_accuracy: 0.9355 - loss: 0.2640 - precision: 0.9260 - recall: 0.9089 - val_auc: 0.9722 - val_binary_accuracy: 0.9262 - val_loss: 0.2603 - val_precision: 0.9351 - val_recall: 0.8733 - learning_rate: 8.5899e-05
Epoch 93/250
87/87 - 0s - 4ms/step - auc: 0.9689 - binary_accuracy: 0.9355 - loss: 0.2526 - precision: 0.9276 - recall: 0.9071 - val_auc: 0.9720 - val_binary_accuracy: 0.9262 - val_loss: 0.2607 - val_precision: 0.9351 - val_recall: 0.8733 - learning_rate: 8.5899e-05
Epoch 94/250
87/87 - 0s - 4ms/step - auc: 0.9697 - binary_accuracy: 0.9333 - loss: 0.2509 - precision: 0.9231 - recall: 0.9062 - val_auc: 0.9720 - val_binary_accuracy: 0.9262 - val_loss: 0.2600 - val_precision: 0.9351 - val_recall: 0.8733 - learning_rate: 6.8719e-05
Epoch 95/250
87/87 - 0s - 3ms/step - auc: 0.9658 - binary_accuracy: 0.9337 - loss: 0.2620 - precision: 0.9209 - recall: 0.9098 - val_auc: 0.9722 - val_binary_accuracy: 0.9262 - val_loss: 0.2591 - val_precision: 0.9351 - val_recall: 0.8733 - learning_rate: 6.8719e-05
Epoch 96/250
87/87 - 0s - 2ms/step - auc: 0.9706 - binary_accuracy: 0.9377 - loss: 0.2502 - precision: 0.9232 - recall: 0.9181 - val_auc: 0.9721 - val_binary_accuracy: 0.9262 - val_loss: 0.2597 - val_precision: 0.9351 - val_recall: 0.8733 - learning_rate: 6.8719e-05
Epoch 97/250
87/87 - 0s - 4ms/step - auc: 0.9672 - binary_accuracy: 0.9322 - loss: 0.2551 - precision: 0.9190 - recall: 0.9080 - val_auc: 0.9720 - val_binary_accuracy: 0.9273 - val_loss: 0.2589 - val_precision: 0.9353 - val_recall: 0.8760 - learning_rate: 6.8719e-05
Epoch 98/250
87/87 - 0s - 3ms/step - auc: 0.9681 - binary_accuracy: 0.9333 - loss: 0.2560 - precision: 0.9216 - recall: 0.9080 - val_auc: 0.9720 - val_binary_accuracy: 0.9273 - val_loss: 0.2584 - val_precision: 0.9353 - val_recall: 0.8760 - learning_rate: 6.8719e-05
Epoch 99/250
87/87 - 0s - 2ms/step - auc: 0.9680 - binary_accuracy: 0.9344 - loss: 0.2533 - precision: 0.9210 - recall: 0.9117 - val_auc: 0.9719 - val_binary_accuracy: 0.9273 - val_loss: 0.2587 - val_precision: 0.9353 - val_recall: 0.8760 - learning_rate: 6.8719e-05
Epoch 100/250
87/87 - 0s - 2ms/step - auc: 0.9728 - binary_accuracy: 0.9355 - loss: 0.2420 - precision: 0.9189 - recall: 0.9172 - val_auc: 0.9721 - val_binary_accuracy: 0.9273 - val_loss: 0.2600 - val_precision: 0.9353 - val_recall: 0.8760 - learning_rate: 6.8719e-05
Epoch 101/250
87/87 - 0s - 2ms/step - auc: 0.9699 - binary_accuracy: 0.9351 - loss: 0.2499 - precision: 0.9267 - recall: 0.9071 - val_auc: 0.9721 - val_binary_accuracy: 0.9273 - val_loss: 0.2602 - val_precision: 0.9353 - val_recall: 0.8760 - learning_rate: 6.8719e-05
Epoch 102/250
87/87 - 0s - 2ms/step - auc: 0.9728 - binary_accuracy: 0.9358 - loss: 0.2441 - precision: 0.9268 - recall: 0.9089 - val_auc: 0.9723 - val_binary_accuracy: 0.9273 - val_loss: 0.2592 - val_precision: 0.9353 - val_recall: 0.8760 - learning_rate: 5.4976e-05
Epoch 103/250
87/87 - 0s - 4ms/step - auc: 0.9706 - binary_accuracy: 0.9369 - loss: 0.2456 - precision: 0.9270 - recall: 0.9117 - val_auc: 0.9722 - val_binary_accuracy: 0.9273 - val_loss: 0.2584 - val_precision: 0.9353 - val_recall: 0.8760 - learning_rate: 5.4976e-05
Epoch 104/250
87/87 - 0s - 3ms/step - auc: 0.9698 - binary_accuracy: 0.9387 - loss: 0.2505 - precision: 0.9298 - recall: 0.9135 - val_auc: 0.9723 - val_binary_accuracy: 0.9273 - val_loss: 0.2588 - val_precision: 0.9353 - val_recall: 0.8760 - learning_rate: 5.4976e-05
Epoch 105/250
87/87 - 0s - 4ms/step - auc: 0.9685 - binary_accuracy: 0.9366 - loss: 0.2510 - precision: 0.9246 - recall: 0.9135 - val_auc: 0.9723 - val_binary_accuracy: 0.9273 - val_loss: 0.2588 - val_precision: 0.9353 - val_recall: 0.8760 - learning_rate: 4.3980e-05
Epoch 106/250
87/87 - 0s - 2ms/step - auc: 0.9724 - binary_accuracy: 0.9366 - loss: 0.2403 - precision: 0.9262 - recall: 0.9117 - val_auc: 0.9724 - val_binary_accuracy: 0.9273 - val_loss: 0.2590 - val_precision: 0.9353 - val_recall: 0.8760 - learning_rate: 4.3980e-05
Epoch 107/250
87/87 - 0s - 3ms/step - auc: 0.9652 - binary_accuracy: 0.9358 - loss: 0.2612 - precision: 0.9244 - recall: 0.9117 - val_auc: 0.9722 - val_binary_accuracy: 0.9273 - val_loss: 0.2582 - val_precision: 0.9353 - val_recall: 0.8760 - learning_rate: 4.3980e-05
Epoch 108/250
87/87 - 0s - 3ms/step - auc: 0.9688 - binary_accuracy: 0.9348 - loss: 0.2517 - precision: 0.9250 - recall: 0.9080 - val_auc: 0.9722 - val_binary_accuracy: 0.9273 - val_loss: 0.2579 - val_precision: 0.9353 - val_recall: 0.8760 - learning_rate: 4.3980e-05
Epoch 109/250
87/87 - 0s - 5ms/step - auc: 0.9680 - binary_accuracy: 0.9337 - loss: 0.2565 - precision: 0.9216 - recall: 0.9089 - val_auc: 0.9721 - val_binary_accuracy: 0.9273 - val_loss: 0.2579 - val_precision: 0.9353 - val_recall: 0.8760 - learning_rate: 4.3980e-05
Epoch 110/250
87/87 - 0s - 4ms/step - auc: 0.9683 - binary_accuracy: 0.9337 - loss: 0.2547 - precision: 0.9280 - recall: 0.9016 - val_auc: 0.9722 - val_binary_accuracy: 0.9273 - val_loss: 0.2578 - val_precision: 0.9353 - val_recall: 0.8760 - learning_rate: 4.3980e-05
Epoch 111/250
87/87 - 0s - 2ms/step - auc: 0.9692 - binary_accuracy: 0.9402 - loss: 0.2502 - precision: 0.9292 - recall: 0.9181 - val_auc: 0.9719 - val_binary_accuracy: 0.9273 - val_loss: 0.2581 - val_precision: 0.9353 - val_recall: 0.8760 - learning_rate: 4.3980e-05
Epoch 112/250
87/87 - 0s - 3ms/step - auc: 0.9722 - binary_accuracy: 0.9358 - loss: 0.2455 - precision: 0.9292 - recall: 0.9062 - val_auc: 0.9719 - val_binary_accuracy: 0.9273 - val_loss: 0.2580 - val_precision: 0.9353 - val_recall: 0.8760 - learning_rate: 4.3980e-05
Epoch 113/250
87/87 - 0s - 2ms/step - auc: 0.9712 - binary_accuracy: 0.9351 - loss: 0.2470 - precision: 0.9283 - recall: 0.9052 - val_auc: 0.9720 - val_binary_accuracy: 0.9273 - val_loss: 0.2579 - val_precision: 0.9353 - val_recall: 0.8760 - learning_rate: 4.3980e-05
Epoch 114/250
87/87 - 0s - 3ms/step - auc: 0.9717 - binary_accuracy: 0.9391 - loss: 0.2426 - precision: 0.9290 - recall: 0.9154 - val_auc: 0.9719 - val_binary_accuracy: 0.9273 - val_loss: 0.2577 - val_precision: 0.9327 - val_recall: 0.8788 - learning_rate: 3.5184e-05
Epoch 115/250
87/87 - 0s - 3ms/step - auc: 0.9726 - binary_accuracy: 0.9395 - loss: 0.2404 - precision: 0.9283 - recall: 0.9172 - val_auc: 0.9720 - val_binary_accuracy: 0.9273 - val_loss: 0.2580 - val_precision: 0.9327 - val_recall: 0.8788 - learning_rate: 3.5184e-05
Epoch 116/250
87/87 - 0s - 4ms/step - auc: 0.9694 - binary_accuracy: 0.9369 - loss: 0.2488 - precision: 0.9270 - recall: 0.9117 - val_auc: 0.9720 - val_binary_accuracy: 0.9273 - val_loss: 0.2578 - val_precision: 0.9327 - val_recall: 0.8788 - learning_rate: 3.5184e-05
Epoch 117/250
87/87 - 0s - 3ms/step - auc: 0.9698 - binary_accuracy: 0.9344 - loss: 0.2481 - precision: 0.9218 - recall: 0.9108 - val_auc: 0.9720 - val_binary_accuracy: 0.9273 - val_loss: 0.2578 - val_precision: 0.9327 - val_recall: 0.8788 - learning_rate: 2.8147e-05
Epoch 118/250
87/87 - 0s - 3ms/step - auc: 0.9696 - binary_accuracy: 0.9351 - loss: 0.2486 - precision: 0.9267 - recall: 0.9071 - val_auc: 0.9720 - val_binary_accuracy: 0.9273 - val_loss: 0.2571 - val_precision: 0.9327 - val_recall: 0.8788 - learning_rate: 2.8147e-05
Epoch 119/250
87/87 - 0s - 3ms/step - auc: 0.9715 - binary_accuracy: 0.9369 - loss: 0.2422 - precision: 0.9270 - recall: 0.9117 - val_auc: 0.9720 - val_binary_accuracy: 0.9273 - val_loss: 0.2572 - val_precision: 0.9327 - val_recall: 0.8788 - learning_rate: 2.8147e-05
Epoch 120/250
87/87 - 0s - 3ms/step - auc: 0.9703 - binary_accuracy: 0.9387 - loss: 0.2464 - precision: 0.9322 - recall: 0.9108 - val_auc: 0.9720 - val_binary_accuracy: 0.9273 - val_loss: 0.2573 - val_precision: 0.9327 - val_recall: 0.8788 - learning_rate: 2.8147e-05
Epoch 121/250
87/87 - 0s - 4ms/step - auc: 0.9711 - binary_accuracy: 0.9358 - loss: 0.2450 - precision: 0.9252 - recall: 0.9108 - val_auc: 0.9719 - val_binary_accuracy: 0.9273 - val_loss: 0.2573 - val_precision: 0.9327 - val_recall: 0.8788 - learning_rate: 2.8147e-05
Epoch 122/250
87/87 - 0s - 4ms/step - auc: 0.9698 - binary_accuracy: 0.9369 - loss: 0.2499 - precision: 0.9278 - recall: 0.9108 - val_auc: 0.9718 - val_binary_accuracy: 0.9273 - val_loss: 0.2574 - val_precision: 0.9327 - val_recall: 0.8788 - learning_rate: 2.2518e-05
Epoch 123/250
87/87 - 0s - 4ms/step - auc: 0.9706 - binary_accuracy: 0.9373 - loss: 0.2463 - precision: 0.9279 - recall: 0.9117 - val_auc: 0.9719 - val_binary_accuracy: 0.9273 - val_loss: 0.2574 - val_precision: 0.9327 - val_recall: 0.8788 - learning_rate: 2.2518e-05

Details on the trained model:

summary(mlp)
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                      ┃ Output Shape             ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ dense (Dense)                     │ (None, 128)              │         3,584 │
├───────────────────────────────────┼──────────────────────────┼───────────────┤
│ dropout (Dropout)                 │ (None, 128)              │             0 │
├───────────────────────────────────┼──────────────────────────┼───────────────┤
│ dense_1 (Dense)                   │ (None, 64)               │         8,256 │
├───────────────────────────────────┼──────────────────────────┼───────────────┤
│ dropout_1 (Dropout)               │ (None, 64)               │             0 │
├───────────────────────────────────┼──────────────────────────┼───────────────┤
│ dense_2 (Dense)                   │ (None, 32)               │         2,080 │
├───────────────────────────────────┼──────────────────────────┼───────────────┤
│ dropout_2 (Dropout)               │ (None, 32)               │             0 │
├───────────────────────────────────┼──────────────────────────┼───────────────┤
│ dense_3 (Dense)                   │ (None, 16)               │           528 │
├───────────────────────────────────┼──────────────────────────┼───────────────┤
│ dropout_3 (Dropout)               │ (None, 16)               │             0 │
├───────────────────────────────────┼──────────────────────────┼───────────────┤
│ dense_4 (Dense)                   │ (None, 1)                │            17 │
└───────────────────────────────────┴──────────────────────────┴───────────────┘
 Total params: 43,397 (169.52 KB)
 Trainable params: 14,465 (56.50 KB)
 Non-trainable params: 0 (0.00 B)
 Optimizer params: 28,932 (113.02 KB)

Visually:

mlp |> plot(show_shapes = TRUE, show_trainable = TRUE)

Looking at the training progression:

training_prog <- 
  history |> 
  as.data.frame() |> 
  tibble() |>
  pivot_wider(values_from = "value", names_from = "metric") |> 
  drop_na(loss)

Loss curves:

training_prog |> 
  ggplot(aes(x = epoch, y = loss, color = data)) +
  geom_line() +
  theme_minimal() +
  labs(
    title = "Training curves",
    subtitle = "Binary cross-entropy loss on training and validation sets, over epochs",
    x = "Epochs",
    y = "Loss",
    color = "Data"
  )

Validation metrics:

training_prog |> 
  select(-c(learning_rate, loss)) |> 
  pivot_longer(-c(epoch, data), names_to = "metric", values_to = "value") |> 
  ggplot(aes(x = epoch, y = value, color = data)) +
  geom_line() +
  facet_wrap(~metric) +
  theme_minimal() +
  labs(
    title = "Training improvements",
    subtitle = "Development of metrics over epochs, validation set",
    x = "Epochs",
    y = "",
    color = "Data"
  )

Collecting final metrics for training set:

class_metrics <- metric_set(accuracy, precision, f_meas)

mlp_metrics_train <- 
  mlp$predict(X_train) |> 
  round() |> 
  as.vector() |> 
  tibble(mlp_pred = _) |> 
  bind_cols(train) |> 
  mutate(mlp_pred = factor(if_else(mlp_pred == 1, "spam", "no spam"), levels = c("spam", "no spam"))) |> 
  class_metrics(truth = spam, estimate = mlp_pred) |> 
  select(-.estimator) |> 
  pivot_wider(names_from = ".metric", values_from = ".estimate") |> 
  mutate(name = "Neural Network")

 1/87 ━━━━━━━━━━━━━━━━━━━━ 11s 131ms/step
76/87 ━━━━━━━━━━━━━━━━━━━━ 0s 673us/step 
87/87 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step  
87/87 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step

Evaluating the model on the test set:

mlp_preds <- 
  mlp$predict(X_test) |> 
  round() |> 
  as.vector() |> 
  tibble(mlp_pred = _) |> 
  bind_cols(test) |> 
  mutate(mlp_pred = factor(if_else(mlp_pred == 1, "spam", "no spam"), levels = c("spam", "no spam")))

 1/29 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step
29/29 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step 
29/29 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step

Metrics:

mlp_preds |> 
  class_metrics(truth = spam, estimate = mlp_pred)
# A tibble: 3 × 3
  .metric   .estimator .estimate
  <chr>     <chr>          <dbl>
1 accuracy  binary         0.939
2 precision binary         0.932
3 f_meas    binary         0.922

Confusion matrix:

mlp_preds |> 
  conf_mat(truth = spam, estimate = mlp_pred) |> 
  autoplot(type = "heatmap")

Competitors

10-fold cross validation:

set.seed(42)

folds <- vfold_cv(train, v = 10, strata = "spam")

To select a proper competitor, I tune three “traditional” models:

  • Penalized logistic regression
  • Random Forest
  • Naive Bayes
set.seed(42)

log_spec <- logistic_reg(
  mode = "classification",
  engine = "glmnet",
  penalty = tune(),
  mixture = 1 # pure L1 regularization
)

log_grid <- tibble(penalty = 10^seq(-5, -1, length.out = 50))

rf_spec <- rand_forest(
  mode = "classification",
  mtry = tune(),
  trees = tune(),
  min_n = tune()
) |> set_engine("ranger", importance = "impurity") # variable importance

rf_grid <- expand_grid(
  mtry = c(4, 8, 12),
  trees = c(100, 500, 1000),
  min_n = c(5, 10, 20)
)

# For naive bayes, we disable kernel density estimation. This yields
# Gaussian naive bayes, so we assume normally distributed features
nb_spec <- naive_Bayes(
  mode = "classification",
  smoothness = tune()
) |> set_engine("naivebayes", usekernel = FALSE)

# just using raw probabilities (smoothing has no effect for Gaussian bayes
# with only numerical features I believe), but also cross-validating
nb_grid <- tibble(smoothness = 0)

Glueing together & tuning:

set.seed(42)

models <- tribble(
  ~name,                 ~spec,    ~grid,
  "Logistic Regression", log_spec, log_grid,
  "Random Forest",       rf_spec,  rf_grid,
  "Naive Bayes",         nb_spec,  nb_grid
)

models <- 
  models |> 
  mutate(
    workflow = map(spec, \(m) workflow() |> add_recipe(spam_rec) |> add_model(m)),
    tuning_res = map2(workflow, grid, function(wf, g) {
      tune_grid(
        wf, 
        resamples = folds,
        grid = g,
        metrics = metric_set(accuracy, precision, f_meas),
        control = control_grid(verbose = TRUE, save_pred = TRUE)
      )
    }),
    metrics = map(tuning_res, collect_metrics)
  )

Tuning results

First, adding 95% confidence intervals to metric estimates (estimated using the results collected across folds):

models <- 
  models |> 
  mutate(
    metrics = map(metrics, function(m) { 
      m |> mutate(lower = mean - 1.96 * std_err, upper = mean + 1.96 * std_err)
    })
  )

Logistic Regression

Here, we tuned the penalty parameter:

models |> 
  filter(name == "Logistic Regression") |> 
  select(name, metrics) |> 
  unnest(metrics) |> 
  ggplot(aes(x = penalty, y = mean, fill = .metric)) +
  geom_ribbon(aes(ymin = lower, ymax = upper), alpha = 0.2) +
  geom_point(aes(color = .metric)) +
  geom_line(aes(color = .metric)) +
  facet_wrap(~.metric, scales = "free_y") +
  scale_x_log10() +
  theme_minimal() +
  theme(legend.position = "none") +
  labs(
    title = "Logistic Regression", 
    subtitle = "Tuning Results", 
    x = expression(lambda),
    y = "",
    caption = "Shaded area indicates 95% confidence interval.\nLogarithmic X-axis."
  )

Random Forest

models |> 
  filter(name == "Random Forest") |> 
  select(name, metrics) |> 
  unnest(metrics) |> 
  mutate(min_n = factor(min_n, ordered = TRUE)) |> 
  ggplot(aes(x = mtry, y = mean, group = min_n)) +
  geom_ribbon(aes(ymin = lower, ymax = upper, fill = min_n), alpha = .1) +
  geom_point(aes(color = min_n, )) +
  geom_line(aes(color = min_n, )) +
  facet_grid(vars(.metric), vars(trees), scale = "free_y") +
  theme_minimal() +
  labs(
    title = "Random Forest",
    subtitle = "Tuning Results, by number of trees",
    x = "Number of randomly sampled features",
    y = "",
    fill = "Min. N in \nnode for split",
    color = "Min. N in \nnode for split",
    caption = "Shaded area indicates 95% confidence interval."
  )

Naive Bayes

models |> 
  filter(name == "Naive Bayes") |> 
  select(name, metrics) |> 
  unnest(metrics) |> 
  ggplot(aes(x = smoothness, y = mean, color = .metric)) +
  geom_point() +
  geom_linerange(aes(ymin = lower, ymax = upper)) +
  facet_wrap(~.metric, nrow = 1) +
  theme_minimal() +
  theme(legend.position = "none", axis.text.x = element_blank()) +
  labs(
    title = "Naive Bayes",
    subtitle = "Metric estimates across 10 folds, with 95% confidence interval",
    x = "",
    y = ""
  )

Uncertainty estimation for best competitors

After tuning, selecting the best models by (1) accuracy, and (2) precision, and showing their uncertainty across folds:

best_train <- 
  models |> 
  mutate(
    # best models by accuracy & precision:
    best_acc = map(tuning_res, \(res) select_best(res, metric = "accuracy")),
    best_prec = map(tuning_res, \(res) select_best(res, metric = "precision")),
    # metrics only for best models:
    across(c(best_acc, best_prec), function(params) {
      map2(tuning_res, params, function(res, p) {
        res |> 
          collect_metrics() |> 
          filter(.config == p$.config)
      })
    }, .names = "{col}_train"),
    # add 95% confidence interval:
    across(ends_with("_train"), function(metrics) {
      map(metrics, function(d) {
        d |> 
          mutate(lower = mean - 1.96 * std_err, upper = mean + 1.96 * std_err)
      })
    })
  ) |> 
  select(name, starts_with("best"))

Accuracy:

best_train |> 
  select(name, best_acc_train) |> 
  unnest(best_acc_train) |> 
  filter(.metric == "accuracy") |> 
  ggplot(aes(x = name, y = mean, color = name)) +
  geom_hline(yintercept = mlp_metrics_train$accuracy, lty = "dashed", color = "grey") +
  geom_point() +
  geom_linerange(aes(ymin = lower, ymax = upper)) +
  theme_minimal() +
  labs(
    title = "Uncertainty of accuracy estimates",
    subtitle = "Competitor models, training data\nEstimated on 10 folds",
    x = "",
    y = "Accuracy",
    caption = "Bars indicate 95% confidence interval"
  ) +
  theme(legend.position = "none") +
  annotate(
    "text", 
    x = 0.67, 
    y = mlp_metrics_train$accuracy + 0.002, 
    label = "Neural Network",
    color = "grey50",
    size = 4
  )

Precision (classifying emails as spam, false positives - i.e. mistakenly labeling “ham” emails as spam - are more costly):

best_train |> 
  select(name, best_prec_train) |> 
  unnest(best_prec_train) |> 
  filter(.metric == "precision") |> 
  ggplot(aes(x = name, y = mean, color = name)) +
  geom_hline(yintercept = mlp_metrics_train$accuracy, lty = "dashed", color = "grey") +
  geom_point() +
  geom_linerange(aes(ymin = lower, ymax = upper)) +
  theme_minimal() +
  labs(
    title = "Uncertainty of precision estimates",
    subtitle = "Competitor models, training data\nEstimated on 10 folds",
    x = "",
    y = "Precision",
    caption = "Bars indicate 95% confidence interval"
  ) +
  theme(legend.position = "none") +
  annotate(
    "text", 
    x = 0.67, 
    y = mlp_metrics_train$accuracy + 0.002, 
    label = "Neural Network",
    color = "grey50",
    size = 4
  )

Random Forest vs. Neural Network on test set

rf_res <- models |> filter(name == "Random Forest")
rf_tuning_res <- rf_res |> pull(tuning_res) |> pluck(1)
rf_wf <- rf_res |> pull(workflow) |> pluck(1)

Considering the best random forest model by precision (most important metric here):

rf_fit_prec <- 
  rf_wf |> 
  finalize_workflow(select_best(rf_tuning_res, metric = "precision")) |>
  fit(train)

Test predictions for both (plus predicted class probability):

nn_preds <- 
  mlp |> 
  predict(X_test) |> 
  as.vector() |> 
  tibble(.pred_spam = _) |> 
  mutate(
    .pred_no_spam = 1 - .pred_spam, 
    .pred_class = round(.pred_spam),
    .pred_class = factor(
      if_else(.pred_class == 1, "spam", "no spam"),
      ordered = TRUE,
      levels = c("spam", "no spam")
    ),
    model = "Neural Network"
  ) |> 
  bind_cols(test |> select(actual = spam))
29/29 - 0s - 1ms/step
rf_preds <- 
  rf_fit_prec |> 
  predict(test) |> 
  bind_cols(rf_fit_prec |> predict(test, type = "prob")) |>
  rename(.pred_no_spam = `.pred_no spam`) |> 
  mutate(
    model = "Random Forest",
    .pred_class = factor(.pred_class, ordered = TRUE, levels = c("spam", "no spam"))
  ) |> 
  bind_cols(test |> select(actual = spam))

test_preds <- bind_rows(nn_preds, rf_preds)

Metrics:

# Some markdown magic:
mark_best <- function(x) {
  map_chr(x, function(val) {
    if (val == max(x))
      return(paste0("**", as.character(round(val, 3)), "**"))
    as.character(round(val, 3))
  })
}

test_preds |> 
  group_by(model) |> 
  nest(-model) |> 
  mutate(
    metrics = map(data, \(preds) {
      preds |> 
        class_metrics(truth = actual, estimate = .pred_class) |> 
        select(-.estimator) |> 
        pivot_wider(names_from = ".metric", values_from = ".estimate")
    })
  ) |> 
  select(model, metrics) |> 
  unnest(metrics) |> 
  rename(f1 = f_meas) |> 
  ungroup() |> 
  select(model, precision, everything()) |> 
  mutate(across(-model, mark_best)) |> 
  rename_with(stringr::str_to_title) |> 
  knitr::kable()
Model Precision Accuracy F1
Neural Network 0.932 0.939 0.922
Random Forest 0.935 0.942 0.926

Looking at model confidence. Graphically, we can see that the neural network is more confident in its correct predictions, but also overconfident in its wrong predictions:

confidence <- 
 test_preds |> 
  mutate(
    confidence = if_else(.pred_class == "spam", .pred_spam, .pred_no_spam),
    correct = if_else(actual == .pred_class, "correct", "incorrect")
  )

confidence |> 
  ggplot(aes(x = correct, y = confidence, fill = model, color = model)) +
  geom_hline(yintercept = c(.5, 1), lty = "dotted", color = "grey50") +
  geom_boxplot(position = position_dodge(width = 0.2), width = .1, outliers = FALSE, alpha = .5) +
  theme_minimal() +
  labs(
    title = "Confidence in predictions",
    subtitle = "By correct/incorrect classification",
    x = "",
    y = "Predicted class probability",
    fill = "Model",
    color = "Model"
  )

Other way of looking at it (making both & then deciding later for report):

confidence |> 
  ggplot(aes(x = confidence, color = model, fill = model)) +
  geom_density(alpha = .34) +
  facet_wrap(~correct, nrow = 2, scale = "free_y") +
  theme_minimal()  +
  labs(
    title = "Confidence in predictions",
    subtitle = "By correct/incorrect classification",
    x = "Predicted class probability",
    y = "Density",
    fill = "Model",
    color = "Model"
  ) +
  theme(aspect.ratio = .5)

ROC curves:

test_preds |> 
  group_by(model) |>
  roc_curve(truth = actual, .pred_spam) |>
  ungroup() |> 
  ggplot(aes(x = 1 - specificity, y = sensitivity, color = model)) +
  geom_line() + 
  geom_abline(linetype = "dotted", color = "grey50") +
  theme_minimal() +
  labs(
    title = "ROC Curves",
    subtitle = "Neural Network & Random Forest, Test set",
    x = "1 - Specificity",
    y = "Sensitivity",
    color = "Model"
  ) +
  theme(aspect.ratio = 1) # square, easier to tell what's going on

RF variable importance:

rf_fit_prec |> 
  extract_fit_parsnip() |> 
  vip::vi() |> 
  slice(1:10) |> 
  ggplot(aes(x = Importance, y = forcats::fct_reorder(Variable, Importance), color = Variable)) +
  geom_segment(aes(xend = 0, yend = Variable), size = 2, alpha = 0.5) +
  geom_point(size = 4) +
  theme_minimal() +
  labs(
    title = "Variable Importance",
    subtitle = "Random Forest",
    x = "Importance (Impurity)",
    y = "Feature"
  ) +
  theme(
    legend.position = "none",
    axis.text.y = element_text(size = 12)
  )

Bootstrapped confidence intervals

Given we cannot use tune::int_pctl() with the neural network, I am just doing both manually.

keras3::set_random_seed(42)

test_boot <- 
  bootstraps(test, times = 500, strata = "spam") |> 
  # no splits, we only want to make predictions:
  mutate(data = map(splits, analysis)) |> 
  select(id, data)

nn_boot <- 
  test_boot |> 
  mutate(
    metrics_dnn = map(data, function(sample) {
      # The model is standalone, not a "workflow",
      # so we need to send the data through the prep pipeline
      # manually & then convert to matrix format
      X <- 
        spam_rec |> 
        prep() |> 
        bake(new_data = sample) |> 
        select(-spam) |> 
        as.matrix() |> 
        unname()
      
      mlp$predict(X, verbose = 0) |> 
        as.vector() |> 
        round() |> 
        tibble(mlp_pred = _) |> 
        bind_cols(sample) |> 
        mutate(
          mlp_pred = factor(
            if_else(mlp_pred == 1, "spam", "no spam"), levels = c("spam", "no spam")
          )
        ) |> 
        class_metrics(truth = spam, estimate = mlp_pred) |> 
        select(-.estimator) |> 
        pivot_wider(names_from = ".metric", values_from = ".estimate") |> 
        mutate(name = "Neural Network")
    })
  ) |> 
  unnest(metrics_dnn)

rf_boot <- 
  test_boot |> 
  mutate(
    metrics_rf = map(data, function(sample) {
      # This is a workflow object containing the preprocessing pipeline
      # and model that can just take any data directly:
      rf_fit_prec |> 
        augment(new_data = sample) |> 
        class_metrics(truth = spam, estimate = .pred_class) |> 
        select(-.estimator) |> 
        pivot_wider(names_from = ".metric", values_from = ".estimate") |> 
        mutate(name = "Random Forest")
    })
  ) |> 
  unnest(metrics_rf)

Evaluating:

fns <- list(
  mean = mean,
  # 95% confidence intervals:
  lower = \(x) mean(x) - 1.96 * (sd(x) / sqrt(500)), # 500 = n
  upper = \(x) mean(x) + 1.96 * (sd(x) / sqrt(500))
)

boot_res <- 
  nn_boot |> 
  select(-c(id, data)) |> 
  bind_rows(rf_boot |> select(-c(id, data))) |> 
  pivot_longer(-name, names_to = "metric", values_to = "estimate") |> 
  group_by(name, metric) |> 
  summarise(across(estimate, fns, .names = "{fn}")) |> 
  ungroup()

boot_res
# A tibble: 6 × 5
  name           metric     mean lower upper
  <chr>          <chr>     <dbl> <dbl> <dbl>
1 Neural Network accuracy  0.939 0.939 0.940
2 Neural Network f_meas    0.922 0.921 0.923
3 Neural Network precision 0.933 0.932 0.934
4 Random Forest  accuracy  0.943 0.942 0.943
5 Random Forest  f_meas    0.926 0.926 0.927
6 Random Forest  precision 0.936 0.935 0.937

Inspecting graphically:

boot_res |> 
  mutate(metric = if_else(metric == "f_meas", "F1", metric) |> stringr::str_to_title()) |> 
  ggplot(aes(x = metric, y = mean, color = name)) +
  geom_point(position = position_dodge(0.1)) +
  geom_errorbar(aes(ymin = lower, ymax = upper), position = position_dodge(0.1), width = .05) +
  theme_minimal() +
  labs(
    title = "Performance on test set",
    subtitle = "Bootstrapped 95% confidence intervals",
    x = "Metric",
    y = "Estimate",
    color = ""
  )